import random
from tqdm import tqdm
from src.web.app import db


# Wrapper around the app's shared DB session for comparing a table with its copy.
class MySqlView:
    """Inspect a MySQL table and a ``<table>_copy`` clone of it side by side.

    ``copy_table`` builds the clone with ``CREATE TABLE ... SELECT``, which
    copies columns and rows but does not recreate the source's indexes. The
    two tables can then be compared with ``EXPLAIN`` (see ``compare_index``).

    Table names are interpolated directly into SQL text, so ``table`` must be
    a trusted identifier and never user input.
    """

    def __init__(self, table):
        """Bind the view to ``table`` and to its ``table + '_copy'`` twin.

        :param table: name of the source table (trusted identifier).
        """
        # Reuse the application's shared session instead of opening a new one.
        self.session = db.session
        self.table = table
        self.table_copy = table + '_copy'

    def query_all_rows(self):
        """Print the row counts of the table and of its copy."""
        sql = "select count(*) from %s;" % self.table
        sql_copy = "select count(*) from %s;" % self.table_copy
        # Fetch each result before issuing the next statement on the session.
        res = self.session.execute(sql).fetchall()
        res_copy = self.session.execute(sql_copy).fetchall()
        print('index. ', res)
        print('copy.  ', res_copy)

    def explain_base_query(self, sql):
        """Run ``EXPLAIN`` on ``sql`` and return the rows from ``fetchall()``.

        :param sql: a complete SQL statement (without the ``explain`` keyword).
        :return: the plan rows exactly as returned by the cursor's ``fetchall``.
        """
        # EXPLAIN must *prefix* the statement.
        prefix = 'explain '
        cursor = self.session.execute(prefix + sql)
        return cursor.fetchall()

    def _insert_row_sql(self, a, b, c):
        """Return an INSERT statement adding the row ``(a, b, c)`` to ``self.table``.

        The table name is substituted with ``%s``. The values use ``%d`` so
        that only integers are rendered into the statement.
        """
        return 'insert into %s (a, b, c) values (%d, %d, %d)' % (self.table, a, b, c)

    def create_mock_data(self, n=10 ** 5, low=0, high=1000):
        """Insert ``n`` rows of random integers into ``self.table``, then commit.

        Each of columns ``a``, ``b`` and ``c`` receives
        ``random.randint(low, high)``, so both bounds are inclusive. The
        defaults insert 100,000 rows with values in [0, 1000].

        :param n: number of rows to insert.
        :param low: smallest value generated.
        :param high: largest value generated.
        """
        for _ in tqdm(range(n)):
            a = random.randint(low, high)
            b = random.randint(low, high)
            c = random.randint(low, high)
            self.session.execute(self._insert_row_sql(a, b, c))
        # Commit once after all rows rather than once per row.
        self.session.commit()

    def copy_table(self):
        """Create ``self.table_copy`` from ``self.table``'s columns and rows, then commit.

        MySQL's ``CREATE TABLE ... SELECT`` does not create the source's
        indexes on the new table. The copy is therefore the unindexed side of
        the comparison in ``compare_index``. The statement errors if the copy
        table already exists.
        """
        sql = 'create table %s select * from %s' % (self.table_copy, self.table)
        self.session.execute(sql)
        self.session.commit()

    def compare_index(self, target_a='432', target_b='985'):
        """Print the same lookup query and its ``EXPLAIN`` plan for both tables.

        The query is ``select * from <table> where a = <target_a> and
        b = <target_b>``. The first plan printed ("add index.") is for
        ``self.table``. The second ("nul index.") is for ``self.table_copy``.
        NOTE(review): this presumably assumes ``self.table`` carries an index
        on ``a``/``b``. That index is not created anywhere visible here, so
        confirm it exists.

        :param target_a: value compared with column ``a`` (inserted verbatim into SQL).
        :param target_b: value compared with column ``b`` (inserted verbatim into SQL).
        """
        sql_base = 'select * from %s where a = %s and b = %s'
        sql = sql_base % (self.table, target_a, target_b)
        sql_copy = sql_base % (self.table_copy, target_a, target_b)
        print(sql)
        print(sql_copy)
        sql_res = self.explain_base_query(sql)
        sql_copy_res = self.explain_base_query(sql_copy)
        print('add index.', sql_res)
        print('nul index.', sql_copy_res)


if __name__ == '__main__':
    # Ad-hoc entry point: print EXPLAIN plans for table 'abc' and 'abc_copy'.
    # Swap in query_all_rows() to print their row counts instead.
    view = MySqlView('abc')
    view.compare_index()